from typing import List

import numpy as np
import onnxruntime as ort
import torch

import internutopia.core.util.math as math_utils
from internutopia.core.robot.articulation_action import ArticulationAction
from internutopia.core.robot.articulation_subset import ArticulationSubset
from internutopia.core.robot.controller import BaseController
from internutopia.core.robot.rigid_body import IRigidBody
from internutopia.core.robot.robot import BaseRobot
from internutopia.core.scene.scene import IScene
from internutopia_extension.configs.controllers import GR1MoveBySpeedControllerCfg


def load_onnx_policy(path: str, device: str = 'cuda:0'):
    """Load an ONNX policy and return a callable that runs inference with it.

    Args:
        path: Filesystem path of the ONNX model file.
        device: torch device on which the returned output tensor is created.
            Defaults to ``'cuda:0'`` (the previous hard-coded behavior); pass
            ``'cpu'`` on machines without a GPU.

    Returns:
        A function ``run_inference(input_tensor) -> torch.Tensor`` that moves
        ``input_tensor`` to the CPU, converts it to numpy, feeds it to the
        model's first input, and returns the model's first output as a tensor
        on ``device``.
    """
    model = ort.InferenceSession(path)
    # The input name is fixed for the lifetime of the session; look it up once
    # instead of on every inference call.
    input_name = model.get_inputs()[0].name

    def run_inference(input_tensor):
        ort_inputs = {input_name: input_tensor.cpu().numpy()}
        ort_outs = model.run(None, ort_inputs)
        return torch.tensor(ort_outs[0], device=device)

    return run_inference


# Default absolute position targets for the 42 joints that are not driven by
# the locomotion policy (54 controlled joints minus the 12 policy outputs).
# Entries 9 and 27 are set to -1.25; every other entry rests at 0.
default_upper_pose = np.zeros(54 - 12)
default_upper_pose[[9, 27]] = -1.25


@BaseController.register('GR1MoveBySpeedController')
class GR1MoveBySpeedController(BaseController):
    """Controller class converting locomotion action to joint positions for GR1 robot.

    An ONNX locomotion policy produces targets for the first ``_NUM_LOWER_JOINTS``
    joints of the articulation subset. The remaining joints are driven directly
    to the caller-supplied ``target_upper_joint_pos``. The policy is queried once
    every ``_ACTION_REPEAT + 1`` calls to :meth:`forward`; in between, the last
    computed joint targets are re-applied unchanged.
    """

    # Default joint offsets, in ``config.joint_names`` order (54 joints).
    # Joint observations are taken relative to these, and the policy's
    # lower-body outputs are scaled and added back onto them.
    # fmt: off
    _DEFAULT_DOF_POS = np.array([
        0.0, 0.0, -0.2618, 0.5236, -0.2618, 0.0,    # joints 0-5
        -0.0, 0.0, -0.2618, 0.5236, -0.2618, 0.0,   # joints 6-11
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,               # joints 12-17
        0.0, 0.2, 0.0, -0.3, 0.0, 0.0,              # joints 18-23
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,               # joints 24-29
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,               # joints 30-35
        0.0, -0.2, 0.0, -0.3, 0.0, 0.0,             # joints 36-41
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,               # joints 42-47
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,               # joints 48-53
    ])
    # fmt: on
    # Shared by every instance, so make accidental in-place edits fail loudly.
    _DEFAULT_DOF_POS.setflags(write=False)

    # Joints whose state is fed to the policy: 0-24 and 36-42 (32 joints).
    _OBS_JOINT_INDEX = np.concatenate([np.arange(0, 25), np.arange(36, 43)])

    # Number of leading joints whose targets come from the policy output.
    _NUM_LOWER_JOINTS = 12

    # One observation frame: command (3) + height (1) + IMU angular velocity (3)
    # + projected gravity (3) + joint pos / joint vel / last targets (3 x 32) = 106.
    _OBS_FRAME_DIM = 10 + 3 * len(_OBS_JOINT_INDEX)
    # Number of stacked frames in the policy input (6 x 106 = 636).
    _OBS_HISTORY_LEN = 6

    # Extra times each policy action is re-applied before querying the policy again.
    _ACTION_REPEAT = 3

    def __init__(self, config: GR1MoveBySpeedControllerCfg, robot: BaseRobot, scene: IScene) -> None:
        super().__init__(config=config, robot=robot, scene=scene)
        print(f'config is {config.policy_weights_path}')
        self._policy = load_onnx_policy(path=config.policy_weights_path)
        self.joint_names = config.joint_names
        # Setup joint subset from joint names; expected to hold the same 54
        # joints (and order) as _DEFAULT_DOF_POS.
        self.joint_subset = ArticulationSubset(self.robot.articulation, self.joint_names)

        # Policy input: the last _OBS_HISTORY_LEN observation frames, oldest first.
        self.policy_input_obs_num = self._OBS_FRAME_DIM * self._OBS_HISTORY_LEN
        self._last_target_joint_positions = np.zeros(len(self._DEFAULT_DOF_POS))
        self._old_policy_obs = torch.zeros(self.policy_input_obs_num)

        # Joint targets most recently sent to the articulation (set by forward()).
        self.applied_joint_positions = None
        # Specifies how many more times the action generated by the policy needs
        # to be repeatedly applied before the policy is queried again.
        self._apply_times_left = 0

    def _make_action(self, joint_positions: np.ndarray) -> ArticulationAction:
        """Wrap joint position targets into an ArticulationAction for the controlled joints."""
        if self.joint_subset is None:
            return ArticulationAction(joint_positions=joint_positions)
        return self.joint_subset.make_articulation_action(joint_positions=joint_positions, joint_velocities=None)

    def _imu_observations(self):
        """Return ``(angular_velocity, projected_gravity)`` as flat numpy arrays.

        Both are world-frame vectors rotated by the inverse of the torso IMU's
        world orientation, i.e. expressed in the IMU link's frame.
        """
        imu_link: IRigidBody = self.robot._imu_in_torso
        imu_quat_w = torch.tensor(imu_link.get_pose()[1]).reshape(1, -1)
        imu_ang_vel_w = torch.tensor(imu_link.get_angular_velocity()[:]).reshape(1, -1)
        imu_ang_vel = np.array(math_utils.quat_rotate_inverse(imu_quat_w, imu_ang_vel_w).reshape(-1))

        gravity_w = torch.tensor([[0.0, 0.0, -1.0]], device='cpu', dtype=torch.float)
        projected_gravity = np.array(math_utils.quat_rotate_inverse(imu_quat_w, gravity_w).reshape(-1))
        return imu_ang_vel, projected_gravity

    def forward(
        self,
        forward_speed: float = 0.0,
        rotation_speed: float = 0.0,
        lateral_speed: float = 0.0,
        height: float = 0.9,
        target_upper_joint_pos=default_upper_pose,
    ) -> ArticulationAction:
        """Compute joint position targets for one control step.

        The policy is only queried when no repeats of the previous action are
        pending; otherwise the previously computed targets are returned again
        and all arguments are ignored.

        Args:
            forward_speed: Forward speed command.
            rotation_speed: Yaw (rotation) speed command.
            lateral_speed: Lateral speed command.
            height: Height command, fed to the policy as an observation.
            target_upper_joint_pos: 42 absolute position targets for subset
                joints 12-53. Not modified.

        Returns:
            ArticulationAction with position targets for the controlled joints.

        Raises:
            ValueError: If ``target_upper_joint_pos`` does not contain 42 elements.
        """
        if self._apply_times_left > 0:
            self._apply_times_left -= 1
            return self._make_action(self.applied_joint_positions)

        default_dof_pos = self._DEFAULT_DOF_POS
        obs_index = self._OBS_JOINT_INDEX
        num_lower = self._NUM_LOWER_JOINTS

        upper_pose = np.asarray(target_upper_joint_pos)
        expected_upper = len(default_dof_pos) - num_lower
        if upper_pose.shape != (expected_upper,):
            raise ValueError(
                f'target_upper_joint_pos must have {expected_upper} elements, got shape {upper_pose.shape}'
            )

        # Get obs and infer.
        imu_ang_vel, projected_gravity = self._imu_observations()

        # Out-of-place subtraction: the returned arrays may alias internal
        # simulator buffers, so they are not modified here.
        joint_pos = self.joint_subset.get_joint_positions() - default_dof_pos
        joint_vel = self.joint_subset.get_joint_velocities()

        tracking_command = np.array([forward_speed, lateral_speed, rotation_speed], dtype=np.float32)
        current_obs = np.concatenate(
            [
                tracking_command * np.array([2.0, 2.0, 0.25]),  # dim = 3
                np.array([height]),  # dim = 1
                imu_ang_vel * 0.25,  # dim = 3
                projected_gravity,  # dim = 3
                joint_pos[obs_index],  # dim = 32
                joint_vel[obs_index] * 0.05,  # dim = 32
                self._last_target_joint_positions[obs_index],  # dim = 32
            ]
        )
        current_obs = np.clip(current_obs, -100.0, 100.0)

        # Slide the history window: drop the oldest frame, append the newest one.
        policy_obs = torch.cat(
            [self._old_policy_obs[self._OBS_FRAME_DIM :], torch.as_tensor(current_obs, dtype=torch.float32)]
        )
        self._old_policy_obs = policy_obs
        policy_obs = policy_obs.reshape(1, self.policy_input_obs_num)

        target_lower_joint_positions = self._policy(policy_obs)[0].detach().cpu().numpy()  # dim = 12
        self._last_target_joint_positions = np.clip(
            np.concatenate([target_lower_joint_positions, upper_pose]), -100.0, 100.0
        )  # dim = 54

        # Lower-body outputs are scaled offsets around the default pose; the
        # upper-body entries are absolute targets and are applied as-is.
        applied = self._last_target_joint_positions * 0.25 + default_dof_pos
        applied[num_lower:] = self._last_target_joint_positions[num_lower:]
        self.applied_joint_positions = applied
        self._apply_times_left = self._ACTION_REPEAT

        return self._make_action(applied)

    def action_to_control(self, action: List | np.ndarray) -> ArticulationAction:
        """Convert input action (in 1d array format) to joint positions to apply.

        Args:
            action (List | np.ndarray): 3- to 5-element 1d sequence containing:
              0. forward_speed (float)
              1. lateral_speed (float)
              2. rotation_speed (float)
              3. height (float, optional; defaults to 0.9 when absent or None)
              4. target_upper_joint_pos (np.ndarray, optional; defaults to
                 ``default_upper_pose`` when absent or None)

        Returns:
            ArticulationAction: joint positions to apply.
        """
        assert len(action) > 2, 'action must contain at least 3 elements'

        height = 0.9
        if len(action) > 3 and action[3] is not None:
            height = action[3]

        target_upper_joint_pos = default_upper_pose
        if len(action) > 4 and action[4] is not None:
            target_upper_joint_pos = action[4]

        return self.forward(
            forward_speed=action[0],
            lateral_speed=action[1],
            rotation_speed=action[2],
            height=height,
            target_upper_joint_pos=target_upper_joint_pos,
        )
